{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test notebook for the MegaDetector Flask API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports and environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import os\n",
    "import json\n",
    "import io\n",
    "\n",
    "from requests_toolbelt import MultipartEncoder\n",
    "from requests_toolbelt.multipart import decoder\n",
    "from PIL import Image\n",
    "\n",
    "sample_input_dir = os.path.expanduser('~/git/MegaDetector/test_images/test_images')\n",
    "\n",
    "ip_address = 'localhost'\n",
    "port = '5050'\n",
    "detect_endpoint = 'http://{}:{}/v1/camera-trap/sync/detect'.format(ip_address, port)\n",
    "print(detect_endpoint)\n",
    "\n",
    "max_images_per_call = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select images to submit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filenames = os.listdir(sample_input_dir)\n",
    "# filenames = [s for s in filenames if (s.lower().endswith('.png') or s.lower().endswith('.jpg'))]\n",
    "filenames = [os.path.join(sample_input_dir,s) for s in filenames]\n",
    "    \n",
    "print('Found {} image(s):'.format(len(filenames)))\n",
    "for fn in filenames:\n",
    "    print(fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Submit images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    'min_confidence': 0.15,\n",
    "    'min_rendering_confidence': 0.2,\n",
    "    'render': True,\n",
    "    'key': None\n",
    "}\n",
    "\n",
    "def clean_filename(s):\n",
    "    s = s.replace('/','_').replace('\\\\','_').replace(':','_')\n",
    "    return s\n",
    "    \n",
    "file_handles = {}\n",
    "\n",
    "if len(filenames) > max_images_per_call:\n",
    "    import random\n",
    "    filenames_to_submit = random.sample(filenames,max_images_per_call)\n",
    "else:\n",
    "    filenames_to_submit = filenames\n",
    "    \n",
    "for fn in filenames_to_submit:\n",
    "    file_handles[fn] = (clean_filename(fn), open(fn, 'rb'), 'image/jpeg')\n",
    "\n",
    "m = MultipartEncoder(fields=file_handles)\n",
    "print(m.content_type)\n",
    "\n",
    "r = requests.post(detect_endpoint, \n",
    "                  params=params,\n",
    "                  data=m,\n",
    "                  headers={'Content-Type': m.content_type})\n",
    "\n",
    "print('Status: {}'.format(r.status_code))\n",
    "\n",
    "if not r.ok:\n",
    "    print('Error: {}\\n{}'.format(r.reason,r.text))\n",
    "    results = None\n",
    "else:\n",
    "    results = decoder.MultipartDecoder.from_response(r)\n",
    "    \n",
    "print('Elapsed time: {}'.format(r.elapsed.total_seconds()))\n",
    "\n",
    "for f in file_handles.values():\n",
    "    f[1].close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decode and print bounding box results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_results = {}\n",
    "\n",
    "for part in results.parts:    \n",
    "    # part is a BodyPart object with b'Content-Type', and b'Content-Disposition'; \n",
    "    # the latter includes 'name' and 'filename' info.\n",
    "    headers = {}\n",
    "    for k, v in part.headers.items():\n",
    "        headers[k.decode(part.encoding)] = v.decode(part.encoding)\n",
    "    if headers.get('Content-Type', None) == 'application/json':\n",
    "        assert len(text_results) == 0\n",
    "        text_results = json.loads(part.content.decode())\n",
    "\n",
    "for image_name in text_results.keys():\n",
    "    print('Bounding boxes for {}:'.format(image_name))\n",
    "    print(text_results[image_name])\n",
    "    print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decode and display rendered images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images = {}\n",
    "for part in results.parts:\n",
    "    # part is a BodyPart object with b'Content-Type', and b'Content-Disposition';\n",
    "    # the latter includes 'name' and 'filename' info\n",
    "    headers = {}\n",
    "    for k, v in part.headers.items():\n",
    "        headers[k.decode(part.encoding)] = v.decode(part.encoding)\n",
    "    if headers.get('Content-Type', None) == 'image/jpeg':\n",
    "        c = headers.get('Content-Disposition')\n",
    "        image_name = c.split('name=\"')[1].split('\"')[0]  # somehow all the filename and name info is all in one string with no obvious forma\n",
    "        image = Image.open(io.BytesIO(part.content))        \n",
    "        images[image_name] = image\n",
    "    \n",
    "    elif headers.get('Content-Type', None) == 'application/json':\n",
    "        text_result = json.loads(part.content.decode())\n",
    "\n",
    "for img_name, img in sorted(images.items()):\n",
    "    print(img_name)\n",
    "    display(img)\n",
    "    print('')"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
